"""
This is the description of the deep NN currently being used.
It is a small CNN for the features with a GRU encoding of the LTL task.
The features and LTL are preprocessed by utils.format.get_obss_preprocessor(...) function:
    - In that function, I transformed the LTL tuple representation into a text representation:
    - Input:  ('until',('not','a'),('and', 'b', ('until',('not','c'),'d')))
    - output: ['until', 'not', 'a', 'and', 'b', 'until', 'not', 'c', 'd']
Each of those tokens gets a one-hot embedding representation by the utils.format.Vocabulary class.
"""


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical, Normal
from detector_model import BaseDetector
import torch_ac

from gym.spaces import Box, Discrete

# from gnns.graphs.GCN import *
# from gnns.graphs.GNN import GNNMaker

from env_model import getEnvModel
from policy_network import PolicyNetwork

# Function from https://github.com/ikostrikov/pytorch-a2c-ppo-acktr/blob/master/model.py
def init_params(m):
    """Re-initialise a module in place if its class name contains "Linear".

    Weights are drawn from a standard normal and then every row is rescaled
    to unit L2 norm; the bias, when present, is zeroed. Modules whose class
    name does not contain "Linear" are left untouched, which makes this
    suitable for ``nn.Module.apply``.
    """
    if "Linear" not in type(m).__name__:
        return

    weight = m.weight.data
    weight.normal_(0, 1)
    # Normalise each output row (dim 1 sums over the inputs) to unit length.
    row_scale = 1 / torch.sqrt(weight.pow(2).sum(1, keepdim=True))
    weight *= row_scale

    if m.bias is not None:
        m.bias.data.fill_(0)


class ACModel(nn.Module, torch_ac.ACModel):
    """Actor-critic network: one shared observation embedding feeding a policy head and a value head.

    Args:
        env: environment, forwarded to ``getEnvModel``.
        obs_space: observation space, forwarded to ``getEnvModel``.
        action_space: action space handed to ``PolicyNetwork``.
        dumb_ac: if true, the actor is a ``PolicyNetwork`` with its default
            configuration and the critic is a single ``nn.Linear`` layer;
            otherwise the actor uses three 64-unit ReLU hidden layers and the
            critic is a 64-64 Tanh MLP.
        no_rm: forwarded to ``getEnvModel``
            (NOTE(review): presumably toggles reward-machine input — confirm).
    """

    def __init__(self, env, obs_space, action_space, dumb_ac: bool, no_rm: bool):
        super().__init__()

        # Configuration kept on the instance.
        self.action_space = action_space
        self.dumb_ac = dumb_ac

        # Observation encoder; its reported size is the input width of both heads.
        self.env_model = getEnvModel(env, obs_space, no_rm)
        self.embedding_size = self.env_model.size()
        print("Model: embedding size:", self.embedding_size)

        width = self.embedding_size
        # The actor is created before the critic layers so parameter creation
        # (and therefore RNG consumption and state_dict order) stays fixed.
        if dumb_ac:
            self.actor = PolicyNetwork(width, self.action_space)
            critic_layers = [nn.Linear(width, 1)]
        else:
            self.actor = PolicyNetwork(width, self.action_space, hiddens=[64, 64, 64], activation=nn.ReLU())
            critic_layers = [
                nn.Linear(width, 64),
                nn.Tanh(),
                nn.Linear(64, 64),
                nn.Tanh(),
                nn.Linear(64, 1),
            ]
        self.critic = nn.Sequential(*critic_layers)

        # Re-initialise every registered submodule with init_params.
        self.apply(init_params)

    def forward(self, obs, use_rm_belief):
        """Run both heads on the shared embedding of ``obs``.

        Args:
            obs: observation batch, passed straight to ``self.env_model``.
            use_rm_belief: forwarded to ``self.env_model`` as a keyword.

        Returns:
            ``(dist, value)``: ``dist`` is whatever the actor returns
            (presumably a torch distribution — see ``PolicyNetwork``);
            ``value`` is the critic output with dim 1 squeezed, i.e. a
            (batch, 1) tensor becomes (batch,).
        """
        embedding = self.env_model(obs, use_rm_belief=use_rm_belief)
        dist = self.actor(embedding)
        value = self.critic(embedding).squeeze(1)
        return dist, value


class LSTMModel(nn.Module):
    """Encode batches of token-index sequences into fixed-size vectors with a 2-layer BiLSTM.

    Args:
        obs_size: vocabulary size (number of rows in the token embedding table).
        word_embedding_size: dimension of each token embedding.
        hidden_dim: hidden size of each LSTM direction.
        text_embedding_size: size of the returned encoding.
    """

    def __init__(self, obs_size, word_embedding_size=32, hidden_dim=32, text_embedding_size=32):
        super().__init__()
        # For all our experiments we want the embedding to be a fixed size so we can "transfer". 
        self.word_embedding = nn.Embedding(obs_size, word_embedding_size)
        self.lstm = nn.LSTM(word_embedding_size, hidden_dim, num_layers=2, batch_first=True, bidirectional=True)
        # Forward and backward summaries are concatenated, hence 2 * hidden_dim inputs.
        self.output_layer = nn.Linear(2*hidden_dim, text_embedding_size)

    def forward(self, text):
        """Return a (batch, text_embedding_size) encoding of ``text``.

        Args:
            text: integer tensor of token indices, shape (batch, seq_len).

        The summary concatenates the top layer's *final* hidden state of each
        direction. Slicing ``output[:, -1, :]`` instead would give the backward
        direction's state after it has read only the last token, so half of
        the encoding would ignore the rest of the sequence.

        NOTE(review): every position, including any padding, is fed to the
        LSTM; confirm how the preprocessor pads sequences.
        """
        _, (h_n, _) = self.lstm(self.word_embedding(text))
        # h_n is (num_layers * 2, batch, hidden_dim), layer-major with forward
        # before backward, so the last two rows are the top layer's forward
        # and backward final states.
        summary = torch.cat((h_n[-2], h_n[-1]), dim=1)
        return self.output_layer(summary)


class GRUModel(nn.Module):
    """Encode batches of token-index sequences into fixed-size vectors with a 2-layer BiGRU.

    Args:
        obs_size: vocabulary size (number of rows in the token embedding table).
        word_embedding_size: dimension of each token embedding.
        hidden_dim: hidden size of each GRU direction.
        text_embedding_size: size of the returned encoding.
    """

    def __init__(self, obs_size, word_embedding_size=32, hidden_dim=32, text_embedding_size=32):
        super().__init__()
        self.word_embedding = nn.Embedding(obs_size, word_embedding_size)
        self.gru = nn.GRU(word_embedding_size, hidden_dim, num_layers=2, batch_first=True, bidirectional=True)
        # Forward and backward summaries are concatenated, hence 2 * hidden_dim inputs.
        self.output_layer = nn.Linear(2*hidden_dim, text_embedding_size)

    def forward(self, text):
        """Return a (batch, text_embedding_size) encoding of ``text``.

        Args:
            text: integer tensor of token indices, shape (batch, seq_len).

        The summary concatenates the top layer's *final* hidden state of each
        direction. Slicing ``output[:, -1, :]`` instead would give the backward
        direction's state after it has read only the last token, so half of
        the encoding would ignore the rest of the sequence.

        NOTE(review): every position, including any padding, is fed to the
        GRU; confirm how the preprocessor pads sequences.
        """
        _, h_n = self.gru(self.word_embedding(text))
        # h_n is (num_layers * 2, batch, hidden_dim), layer-major with forward
        # before backward, so the last two rows are the top layer's forward
        # and backward final states.
        summary = torch.cat((h_n[-2], h_n[-1]), dim=1)
        return self.output_layer(summary)



